package com.xiaojiezhu.spark.rdd.action;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;

import java.io.Serializable;
import java.util.Arrays;

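/**
 * Example of the Spark {@code aggregate} action: computes the sum, the element
 * count and the average of a small {@code JavaRDD<Integer>}.
 * Created 2017-11-26 20:23.
 *
 * @author 朱小杰
 */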
public class JavaAggregate {
    /**
     * Mutable accumulator holding a running sum and the number of values
     * that were added to it, from which an average can be derived.
     *
     * <p>The counters are {@code long} so that summing many {@code int}
     * values does not silently wrap around.
     */
    static class AvgCount implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Running sum of all accumulated values. */
        private long total;
        /** Number of values accumulated so far. */
        private long num;

        /**
         * @param total initial sum
         * @param num   initial number of values
         */
        public AvgCount(int total, int num) {
            this.total = total;
            this.num = num;
        }

        /**
         * @return {@code total / num} as a double; {@code NaN} when nothing
         *         has been accumulated yet (both counters are zero)
         */
        public double avg() {
            return total / (double) num;
        }
    }

    /**
     * Folds a single value into an accumulator: adds it to the sum and
     * increments the count. The accumulator passed in is modified in place
     * and returned.
     */
    static Function2<AvgCount, Integer, AvgCount> addAndCount = new Function2<AvgCount, Integer, AvgCount>() {
        // x is the newly added value
        @Override
        public AvgCount call(AvgCount avgCount, Integer x) throws Exception {
            avgCount.total += x;
            avgCount.num++;
            return avgCount;
        }
    };

    /**
     * Merges two accumulators: the sum and count of the second one are added
     * to the first one, which is modified in place and returned.
     */
    static Function2<AvgCount, AvgCount, AvgCount> combine = new Function2<AvgCount, AvgCount, AvgCount>() {
        @Override
        public AvgCount call(AvgCount avgCount, AvgCount avgCount2) throws Exception {
            avgCount.total += avgCount2.total;
            avgCount.num += avgCount2.num;
            return avgCount;
        }
    };

    /**
     * Runs the example on a local Spark master with the values 1, 2, 3, 4 and
     * prints the sum, the count and the average.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("app");

        // try-with-resources closes the context even if the job throws,
        // so the Spark context is not leaked.
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));

            // The initial (zero) value of the aggregation
            AvgCount initValue = new AvgCount(0, 0);
            // Three arguments: the initial value, the function that folds in
            // one element, and the function that merges partial results
            AvgCount result = rdd.aggregate(initValue, addAndCount, combine);

            System.out.println("总和:" + result.total + " ,值个数:" + result.num + " ,平均值:" + result.avg());
        }
    }
}
